import numpy as np
import os
import json
import csv
import random
from pymongo import MongoClient
# Candidate hashtag lists. Each assignment below REPLACES the previous one,
# so only the final ``htlist`` assignment is used by the export loop; the
# earlier lists are kept only as reference drafts.
htlist=['holidayvibes','gamenight','happyholidays','carsoftiktok','fallguysmoments','homecooked','veteransday','youwantmore','interiordesign','coldweather','wildanimals','mycostume','meleaving','mypfp','catchphrases','watchmegrow','holidaycrafts','growupwithme','clingypet','happyhanukkah','lunarnewyear',
       'tabletop','comfortfood','selfimprovement','2021affirmations','perfectmatch','givingszn','holidaycountdown','bakingszn','holidaymusic','familyimpression','inkdrawing','WeekendVibes','recordsday','productivity','smallbusiness'
       ,'falldiy','whenwewereyounger','yellow','ComingOfAge','artmas','gaminglife','gamingsetup','hellowinter','planttiktok','housetour','neonshadow','homeoffice','raisedby','foodtiktok','valentinesday','yougotthis','stemlife','makeitvogue']

# Overwritten below. Note: this draft contains case variants of the same tag
# (e.g. 'MyHaul'/'myhaul', 'NFLplayoffs'/'nflplayoffs'), which would be
# treated as distinct hashtags if this list were used.
htlist=[ 'oneyourthere', 'sfxmakeup', 'tiktokfood', 'worldseries', 'halloweenishere', 'animation', 'youwantmore', 'halloweenlook', 'nativefamily', 'welldone', 'happyhalloween', 'OhNo', 'holidaytiktok', 'rnbvibes',
        'myrecommendation', 'fallfashion', 'myhobby', 'fanedit', 'motivationmonday', 'wip', 'veteransday', 'whereilive', 'bekind', 'onhold', 'diwali', 'RoomTour', 'foodtiktok', 'holidaysourway',
        'nbadraft', 'ourtype', 'nonuancenovember', 'theatrekids', 'needtoknow', 'graphicdesign', 'familyrecipe', 'onlinedating', 'givingthanks', 'readysetshop', 'diceroll',
        'homecooked', 'easydiy', 'syndouch1', 'howbizarre', 'ImAGhost', 'personalfinance', 'thinkingabout', 'RatatouilleMusical', 'ootd', 'holidaydecor',
        'goodmorning', 'festivefashion', 'haventseen', 'lovestory', 'wishlist', 'YearOnTikTok', 'nbaisback', 'perfectgifts', 'holidaytreats', 'wrappinggifts', 'feliznavidad', 'timewarpjump', 'christmas2020', 'MyHaul', 'cozyathome', 'myhaul', 'winterfit',
        'withouttellingme', 'hyperfixated', 'whatilearned', 'joedizzle', 'WordsOfWisdom', 'bye2020', 'rareaesthetic', 'welcome2021', 'dailyvlog', 'easyrecipe', 'mystyle', 'Bye2020', 'myroutine', 'problemstop', 'gamergoals', 'projectcar', 'homemade', 'inlove', 'GreenScreenScan', 'tortillatrend', 'NFLplayoffs', 'FitnessRoutine', 'IsThisAvailable',
        'wee', 'weirdpets', 'fitnessroutine', 'isthisavailable', 'moneytok', 'nflplayoffs', 'healthycooking', 'nhlfaceoff', 'groupchat', 'NHLFaceOff', 'winterfashion', 'skincare101', 'zodiacsign', 'homeimprovement', 'seashanty', 'cleantok', 'visionboard', 'mlkday', 'joblife', 'foodie', 'timewarpwaterfall', 'plantparent', 'WinterMagic',
        'notaperfectperson', 'tiktoktutorial', 'cocinando', 'OlympicsCountdown', 'couplethings', 'meditation101', 'winterbeautytips', 'tiktokdiy', 'typing', 'roundofapplause', 'xgamesmode', 'feelinggood', 'RoyalRumble', 'emophase', 'favmoriteslippers', 'wintersports', 'makeblackhistory', 'relationshipstorytime', 'albumcover', 'stepbystep',
        'fetapasta', 'womeninsports', 'healthyheart', 'imbusyrightnow', 'beautyhacks', 'tiktoktailgate', 'puppybowl', 'superbowllv', 'melaninmagic', 'coversforlovers', 'kissyourpet', 'womeninstem', 'galentinesday', 'valentinesday', 'loveyourinsecurities', 'tiktokfashionmonth', 'blackcreatives', 'colddays',
        'mifamilia', 'careeradvice', 'stopasianhate', 'perfectdrink', 'snowstorm', 'carhacks', 'homeproject', 'blackandproud', 'dramaticmoments', 'bakedoats', 'homecook', 'laughingduet', 'yoga101', 'somethingyoulearned', 'upcycling', 'fantheory', 'tiktokfitness', 'gamingtiktok', 'seitan', 'glasspainting', 'whenwomenwin', 'science101', 'thriftflip', 'Lifestyle',
        'tiktokwildlifeday', 'dayandnight', 'ontherunway', 'homediy', 'nbaallstar']


# Overwritten on the next line.
htlist=['makeitvogue', 'wip', 'haventseen', 'holidayvibes', 'ootd', 'bekind', 'personalfinance', 'cozyathome', 'gamenight', 'happyholidays', 'veteransday', 'RoomTour', 'meleaving', 'yellow', 'theatrekids', 'carsoftiktok', 'yougotthis', 'ImAGhost', 'holidaytiktok', 'mycostume', 'halloweenlook', 'happyhalloween', 'homecooked', 'welldone', 'motivationmonday', 'coldweather', 'thinkingabout', 'wildanimals', 'youwantmore', 'mypfp', 'catchphrases', 'nonuancenovember', 'ComingOfAge', 'neonshadow', 'ourtype', 'fanedit', 'whenwewereyounger', 'needtoknow', 'watchmegrow', 'cleantok', 'graphicdesign', 'valentinesday', 'holidaycrafts', 'gaminglife', 'growupwithme', 'readysetshop', 'holidaysourway', 'gamingsetup', 'onlinedating', 'raisedby', 'holidaycountdown', 'recordsday', '2021affirmations', 'fallguysmoments', 'myhobby', 'housetour', 'tiktokfood', 'homeoffice', 'happyhanukkah', 'whereilive', 'myrecommendation', 'clingypet', 'inkdrawing', 'worldseries', 'smallbusiness', 'lunarnewyear', 'planttiktok', 'selfimprovement', 'productivity', 'comfortfood', 'tabletop', 'animation', 'artmas', 'cocinando', 'bakingszn', 'easydiy', 'diceroll', 'familyimpression', 'rnbvibes', 'festivefashion', 'stemlife', 'interiordesign', 'holidaymusic', 'falldiy', 'hellowinter', 'foodtiktok', 'holidaydecor', 'nbadraft', 'halloweenishere', 'christmas2020', 'howbizarre', 'sfxmakeup', 'WeekendVibes', 'givingthanks']
# Effective list: this final assignment is the one the export loop iterates
# over. Each entry is used both as a substring to find dump files and as the
# MongoDB collection name (``db[ht]``).
htlist=['wip', 'haventseen', 'ootd', 'bekind', 'personalfinance', 'cozyathome', 'RoomTour', 'theatrekids', 'ImAGhost', 'holidaytiktok', 'halloweenlook', 'happyhalloween', 'welldone', 'motivationmonday', 'thinkingabout', 'nonuancenovember', 'ourtype', 'fanedit', 'needtoknow', 'cleantok', 'graphicdesign', 'readysetshop', 'holidaysourway', 'onlinedating', 'myhobby', 'tiktokfood', 'whereilive', 'myrecommendation', 'worldseries', 'animation', 'cocinando', 'easydiy', 'diceroll', 'rnbvibes', 'festivefashion', 'holidaydecor', 'nbadraft', 'halloweenishere', 'christmas2020', 'howbizarre', 'sfxmakeup', 'givingthanks']

import pickle as pkl
import pymongo
# Shared resources for the export below: the MongoDB 'tiktok' database (one
# collection per hashtag), the tokenizer, and the word-embedding lookups.
client = MongoClient(port=27017)
from tokenizer import tokenize
db = client['tiktok']
# ids[hashtag][date] -> {'first': [...], 'random': [...]} selected video ids;
# filled in by the export loop.
ids = {}
sequence_len = 20  # tokens per encoded text (truncated, then zero-padded)
embedding_matrix = np.load('E:\\data_pi\\embedding_matrix.npy')
embedding_matrix_norm = np.load('E:\\data_pi\\embedding_matrix_norm.npy')
# token -> row index into embedding_matrix / embedding_matrix_norm.
# Close the pickle file deterministically instead of leaking the handle.
with open('E:\\data_pi\\word_index.pkl', 'rb') as word_index_file:
    word_index = pkl.load(word_index_file)
# ---------------------------------------------------------------------------
# Export per-(hashtag, day, selection) feature arrays.
#
# For every hashtag in ``htlist``, up to MAX_DAYS daily dump files are read.
# From each file the distinct video ids of the first MAX_LINES records are
# collected and two selections of at most SAMPLE_SIZE ids are made ('first'
# and 'random'), recorded in ``ids[ht][dt]``. For each selection, the matching
# MongoDB documents in ``db[ht]`` that carry every required feature are turned
# into row-aligned arrays and written to OUT_DIR as
# ``<prefix><ht>_<dt>_<seq>.npy``.
# ---------------------------------------------------------------------------
TIKTOK_DIR = 'D:\\Work\\Tool\\tiktok\\TikToks\\'  # input JSON-lines dumps
OUT_DIR = 'E:\\simscore_ht\\'                    # output directory for .npy files
MAX_DAYS = 14       # daily files used per hashtag
MAX_LINES = 1500    # records read per daily file
SAMPLE_SIZE = 100   # ids kept per selection
IMG_DIM = 4096      # leading values of each frame vector kept as the frame feature
IMG_SEQ_LEN = 13    # frames kept per video (truncated, then zero-padded)
# Editing features exported per video, in column order. A document is only
# exported when all of them are present (a missing one would raise KeyError).
EDIT_KEYS = ('video_len', 'var_sb', 'var_sb_c', 'var_vgg', 'var_vgg_c',
             'var_yamnet', 'sticker_num', 'avg_sticker_length', 'avg_scences')


def _read_video_ids(path):
    """Return the distinct video ids among the first MAX_LINES records of *path*.

    *path* is read as JSON lines; a record carries its id either under ``'id'``
    or under ``'itemInfos' -> 'id'``. Ids keep first-seen order; records with
    neither key are skipped.
    """
    seen = set()  # O(1) membership instead of scanning the list per record
    tids = []
    with open(path, 'r', encoding='utf-8', newline='\n') as fh:
        for lineno, line in enumerate(fh):
            if lineno >= MAX_LINES:  # check before parsing the next record
                break
            record = json.loads(line)
            if 'id' in record:
                vid = record['id']
            elif 'itemInfos' in record:
                vid = record['itemInfos']['id']
            else:
                continue  # no id: skip instead of spending a sample slot on ''
            if vid not in seen:
                seen.add(vid)
                tids.append(vid)
    return tids


def _select_ids(tids):
    """Return the ``{'first': ..., 'random': ...}`` id selections for one day.

    With fewer than SAMPLE_SIZE ids both selections are the full list in file
    order; otherwise 'first' is the leading SAMPLE_SIZE ids and 'random' is a
    sample of SAMPLE_SIZE ids drawn without replacement.
    """
    if len(tids) < SAMPLE_SIZE:
        return {'first': tids, 'random': tids}
    return {'first': tids[:SAMPLE_SIZE],
            'random': random.sample(tids, SAMPLE_SIZE)}


def _has_all_features(obj):
    """Return True when *obj* has non-empty image/text/YAMNet features and all EDIT_KEYS."""
    video = obj['video_feature']
    return (len(video['img_embed']) > 0
            and len(obj['text_feature']['text']) > 0
            and len(video['audio']['yamnet']) > 0
            and all(key in video['editing'] for key in EDIT_KEYS))


def _image_features(img_embed):
    """Split per-frame vectors into frame features and an element-wise max tail.

    Returns ``(frames, pooled)``: ``frames`` holds the first IMG_DIM values of
    every frame (in document key order); ``pooled`` is the element-wise
    maximum over all frames of the values past IMG_DIM, seeded from frame '0'.
    """
    pooled = list(img_embed['0'][IMG_DIM:])
    frames = []
    for vec in img_embed.values():
        frames.append(vec[:IMG_DIM])
        for i in range(IMG_DIM, len(vec)):
            if pooled[i - IMG_DIM] < vec[i]:
                pooled[i - IMG_DIM] = vec[i]
    return frames, pooled


def _encode_text(text):
    """Encode *text* as (word indices, embeddings, normalised embeddings).

    Only the first ``sequence_len`` tokens are considered; tokens missing from
    ``word_index`` are dropped (not replaced), then every output is padded to
    ``sequence_len`` entries with index 0 and zero vectors.
    """
    indices, embeds, embeds_norm = [], [], []
    for word in tokenize(text)[:sequence_len]:
        if word not in word_index:
            continue
        idx = word_index[word]
        indices.append(idx)
        embeds.append(embedding_matrix[idx])
        embeds_norm.append(embedding_matrix_norm[idx])
    pad = sequence_len - len(indices)
    indices.extend([0] * pad)
    embeds.extend(np.zeros(embedding_matrix.shape[1]) for _ in range(pad))
    embeds_norm.extend(np.zeros(embedding_matrix_norm.shape[1]) for _ in range(pad))
    return indices, embeds, embeds_norm


def _export_selection(ht, dt, seq, selected_ids):
    """Build and save the feature arrays for one (hashtag, day, selection)."""
    collection = db[ht]
    pooled_rows, frame_rows, yamnet_rows, edit_rows = [], [], [], []
    text_idx, text_emb, text_emb_norm = [], [], []
    sticker_idx, sticker_emb, sticker_emb_norm = [], [], []
    for vid in selected_ids:
        obj = collection.find_one({'_id': vid})
        if obj is None or not _has_all_features(obj):
            continue  # unknown id or incomplete features: leave it out
        video = obj['video_feature']

        frames, pooled = _image_features(video['img_embed'])
        frames = frames[:IMG_SEQ_LEN]
        frames += [np.zeros(IMG_DIM) for _ in range(IMG_SEQ_LEN - len(frames))]
        pooled_rows.append(pooled)
        frame_rows.append(frames)
        yamnet_rows.append(video['audio']['yamnet'])
        edit_rows.append([video['editing'][key] for key in EDIT_KEYS])

        idx, emb, emb_norm = _encode_text(obj['text_feature']['text'])
        text_idx.append(idx)
        text_emb.append(emb)
        text_emb_norm.append(emb_norm)

        # Sticker texts are concatenated, each followed by a single space.
        sticker_text = ''.join(item + ' ' for item in obj['text_feature']['stickerText'])
        idx, emb, emb_norm = _encode_text(sticker_text)
        sticker_idx.append(idx)
        sticker_emb.append(emb)
        sticker_emb_norm.append(emb_norm)

    outputs = (
        ('image_embed_prob_', np.array(pooled_rows).astype(np.float32)),
        # float32 like the other dense feature arrays.
        ('yamnet_embed_', np.array(yamnet_rows).astype(np.float32)),
        ('text_', text_idx),
        ('edit_embed_', np.array(edit_rows).astype(np.float32)),
        ('image_embed_', np.array(frame_rows).astype(np.float32)),
        ('text_embed_', text_emb),
        ('text_embed_norm_', text_emb_norm),
        ('text_sticker_', sticker_idx),
        ('text_sticker_embed_', sticker_emb),
        ('text_sticker_embed_norm_', sticker_emb_norm),
    )
    suffix = ht + '_' + dt + '_' + seq
    for prefix, data in outputs:
        np.save(OUT_DIR + prefix + suffix, data)


# List the dump directory once; sorted so the MAX_DAYS files picked per
# hashtag do not depend on filesystem listing order.
tiktok_files = sorted(os.listdir(TIKTOK_DIR))
for ht in htlist:
    dayc = 0
    ids[ht] = {}
    for file in tiktok_files:
        # NOTE(review): substring match -- a hashtag contained in another
        # file's name also matches; confirm against the file naming scheme.
        if ht not in file:
            continue
        if dayc >= MAX_DAYS:
            break
        # Presumably "<prefix>_<hashtag>_<date>.json"; field 2 is the date.
        dt = file.split('_')[2].replace('.json', '')
        tids = _read_video_ids(TIKTOK_DIR + file)
        dayc += 1
        ids[ht][dt] = _select_ids(tids)
        for seq in ('first', 'random'):
            # Export the ids of THIS selection (previously 'first' was reused
            # for both, so the 'random' files duplicated the 'first' ones).
            _export_selection(ht, dt, seq, ids[ht][dt][seq])


# This script prepares the data.